/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "../../../../../common/data_utils.h"
#include "aclnn_layernorm_grad_custom.h"
#include "acl/acl_rt.h"
#include "acl/acl.h"
#include <stdio.h>
#include <stdlib.h>

/**
 * Initializes ACL, binds the calling thread to `device` and creates a stream on it.
 *
 * Every failure path undoes the steps that already succeeded, so a nullptr
 * return means nothing is left for the caller to release.
 *
 * @param device  Device id passed to aclrtSetDevice.
 * @return The created stream, or nullptr if any of the three steps failed.
 */
aclrtStream CreateStream(int32_t device)
{
    if (aclInit(nullptr) != ACL_SUCCESS) {
        printf("acl init failed\n");
        return nullptr;
    }
    if (aclrtSetDevice(device) != ACL_SUCCESS) {
        printf("Set device failed\n");
        CHECK_ACL(aclFinalize());
        return nullptr;
    }
    aclrtStream stream = nullptr;
    if (aclrtCreateStream(&stream) != ACL_SUCCESS) {
        printf("Create stream failed\n");
        // The device was successfully set above, so it has to be reset before
        // finalizing; otherwise the device context stays bound on this path.
        CHECK_ACL(aclrtResetDevice(device));
        CHECK_ACL(aclFinalize());
        return nullptr;
    }
    return stream;
}

/**
 * Tears down what CreateStream set up: destroys the stream, resets the device
 * and finalizes ACL, in that order.
 *
 * A failing step is reported but does not stop the remaining steps.
 *
 * @param stream  Stream to destroy.
 * @param device  Device id passed to aclrtResetDevice.
 */
void DestroyStream(aclrtStream stream, int32_t device)
{
    CHECK_ACL(aclrtDestroyStream(stream));

    const auto resetStatus = aclrtResetDevice(device);
    if (resetStatus != ACL_SUCCESS) {
        printf("Reset device failed\n");
    }

    const auto finalizeStatus = aclFinalize();
    if (finalizeStatus != ACL_SUCCESS) {
        printf("Finalize acl failed\n");
    }
}

/**
 * Releases the device buffers and tensor objects held in two parallel arrays.
 *
 * The two arrays are handled independently: a device buffer is freed even when
 * the matching tensor slot is null (for example when tensor creation failed
 * after the buffer had been allocated). Released slots are set to nullptr so a
 * repeated call on the same arrays is harmless.
 *
 * @param tensors      Array of tensor handles; null entries are skipped.
 * @param devMem       Array of device buffers; null entries are skipped.
 * @param tensorCount  Number of entries in both arrays.
 */
void DestroyTensor(aclTensor *tensors[], void *devMem[], int32_t tensorCount)
{
    for (int32_t i = 0; i < tensorCount; i++) {
        if (devMem[i] != nullptr) {
            CHECK_ACL(aclrtFree(devMem[i]));
            devMem[i] = nullptr;
        }
        if (tensors[i] != nullptr) {
            CHECK_ACL(aclDestroyTensor(tensors[i]));
            tensors[i] = nullptr;
        }
    }
}

/**
 * Description of one operator tensor used by this test driver.
 *
 * Instances are built with positional aggregate initializers in main(), so the
 * field order below must not change.
 */
struct tensorInfo {
    int64_t *dims;      // per-axis extents; GetDataSize() returns 0 when this is null
    int64_t dimCnt;     // number of entries in dims
    aclDataType dtype;  // forwarded to aclCreateTensor() as the data type
    aclFormat fmt;      // forwarded to aclCreateTensor() as the format
};

/**
 * Computes the byte size of the tensor described by `desc`.
 *
 * The size is the product of all `dimCnt` extents multiplied by sizeof(float);
 * the `dtype` field is not consulted.
 *
 * @param desc  Tensor description; must not be null.
 * @return Size in bytes, or 0 when `desc->dims` is null.
 */
int64_t GetDataSize(struct tensorInfo *desc)
{
    if (desc->dims == nullptr) {
        return 0;
    }

    int64_t elementCount = 1;
    for (int64_t axis = 0; axis < desc->dimCnt; ++axis) {
        elementCount *= desc->dims[axis];
    }

    const int64_t byteCount = elementCount * sizeof(float);
    return byteCount;
}

static bool CompareResult(const void *outputData, int64_t outSize, std::string goldenName)
{
    void *goldenData;
    CHECK_ACL(aclrtMallocHost((void **)(&goldenData), outSize));
    size_t goldenSize = outSize;
    bool ret = ReadFile("../output/golden_" + goldenName + ".bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden_%s.bin success!\n", goldenName.c_str());
    } else {
        CHECK_ACL(aclrtFreeHost(goldenData));
        return false;
    }
    constexpr float EPS = 1e-4;
    int64_t wrongNum = 0;

    for (auto i = 0; i < outSize / sizeof(float); i++) {
        float a = (reinterpret_cast<const float *>(outputData))[i];
        float b = (reinterpret_cast<const float *>(goldenData))[i];
        float ae = std::abs(a - b);
        float re = ae / abs(b);
        if (ae > EPS && re > EPS) {
            printf("CompareResult golden_output_%s.bin failed output is %lf, golden is %lf\n", goldenName.c_str(), a,
            b);
            wrongNum++;
        }
    }
    CHECK_ACL(aclrtFreeHost(goldenData));

    if (wrongNum != 0) {
        return false;
    } else {
        printf("CompareResult golden_output_%s.bin success\n", goldenName.c_str());
        return true;
    }
}

int32_t main(void)
{
    int64_t inputX[] = {2, 32, 16};
    int64_t inputDy[] = {2, 32, 16};
    int64_t inputVariance[] = {2, 32};
    int64_t inputMean[] = {2, 32};
    int64_t inputGamma[] = {16};
    int64_t outputPdX[] = {2, 32, 16};
    int64_t resForGamma[] = {2, 32, 16};
    struct tensorInfo tensorDesc[] = {{inputX, 3, ACL_FLOAT, ACL_FORMAT_ND},
                                      {inputDy, 3, ACL_FLOAT, ACL_FORMAT_ND},
                                      {inputVariance, 2, ACL_FLOAT, ACL_FORMAT_ND},
                                      {inputMean, 2, ACL_FLOAT, ACL_FORMAT_ND},
                                      {inputGamma, 1, ACL_FLOAT, ACL_FORMAT_ND},
                                      {outputPdX, 3, ACL_FLOAT, ACL_FORMAT_ND},
                                      {resForGamma, 3, ACL_FLOAT, ACL_FORMAT_ND},
                                     };
    std::string ParamNames[] = {
        "inputX",
        "inputDy",
        "inputVariance",
        "inputMean",
        "inputGamma",
        "outputPdX",
        "resForGamma",
    };
    aclrtStream stream = CreateStream(0);
    if (stream == nullptr) {
        return -1;
    }
    int32_t tensorCount = sizeof(tensorDesc) / sizeof(struct tensorInfo);
    aclTensor *tensors[tensorCount];
    void *devMem[tensorCount];
    for (auto i = 0; i < tensorCount; i++) {
        void *data;
        struct tensorInfo *info = &(tensorDesc[i]);
        int64_t size = GetDataSize(info);
        if (size == 0) {
            tensors[i] = nullptr;
            devMem[i] = nullptr;
            continue;
        }
        CHECK_ACL(aclrtMalloc(&data, size, ACL_MEM_MALLOC_HUGE_FIRST));
        // read input
        if (i < 5) {
            size_t inputSize = size;
            void *dataHost;
            CHECK_ACL(aclrtMallocHost((void **)(&dataHost), inputSize));
            ReadFile("../input/input_" + ParamNames[i] + ".bin", inputSize, dataHost, inputSize);
            CHECK_ACL(aclrtMemcpy(data, size, dataHost, size, ACL_MEMCPY_HOST_TO_DEVICE));
            CHECK_ACL(aclrtFreeHost(dataHost));
        }
        devMem[i] = data;
        tensors[i] =
            aclCreateTensor(info->dims, info->dimCnt, info->dtype, nullptr, 0, info->fmt, info->dims, info->dimCnt, data);
    }

    size_t workspaceSize = 0;
    aclOpExecutor *handle;
    int32_t ret;
    ret = aclnnLayernormGradCustomGetWorkspaceSize(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
        tensors[5], tensors[6], &workspaceSize, &handle);
    if (ret != ACL_SUCCESS) {
        printf("aclnnLayernormGradCustomGetWorkspaceSize failed. error code is %u\n", ret);
        DestroyTensor(tensors, devMem, tensorCount);
        DestroyStream(stream, 0);
        return ret;
    }
    printf("aclnnLayernormGradCustomGetWorkspaceSize ret %u workspace size %lu\n", ret, workspaceSize);
    void *workspace = nullptr;
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    ret = aclnnLayernormGradCustom(workspace, workspaceSize, handle, stream);
    printf("aclnnLayernormGradCustom ret %u\n", ret);
    CHECK_ACL(aclrtSynchronizeStream(stream));

    uint8_t *outputPdXHost, *resForGammaHost;
    int64_t outputPdXHostSize = GetDataSize(&(tensorDesc[5]));
    int64_t resForGammaHostSize = GetDataSize(&(tensorDesc[6]));

    CHECK_ACL(aclrtMallocHost((void **)(&outputPdXHost), outputPdXHostSize));
    CHECK_ACL(aclrtMallocHost((void **)(&resForGammaHost), resForGammaHostSize));

    CHECK_ACL(aclrtMemcpy(outputPdXHost, outputPdXHostSize, devMem[5], outputPdXHostSize, ACL_MEMCPY_DEVICE_TO_HOST));
    CHECK_ACL(aclrtMemcpy(resForGammaHost, resForGammaHostSize, devMem[6], resForGammaHostSize, ACL_MEMCPY_DEVICE_TO_HOST));

    WriteFile("../output/output_outputPdX.bin", outputPdXHost, outputPdXHostSize);
    WriteFile("../output/output_resForGamma.bin", resForGammaHost, resForGammaHostSize);

    bool goldenResult = true;
    goldenResult &= CompareResult(outputPdXHost, outputPdXHostSize, ParamNames[5]);
    goldenResult &= CompareResult(resForGammaHost, resForGammaHostSize, ParamNames[6]);
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }
    if (workspaceSize != 0) {
        CHECK_ACL(aclrtFree(workspace));
    }
    CHECK_ACL(aclrtFreeHost(outputPdXHost));
    CHECK_ACL(aclrtFreeHost(resForGammaHost));

    DestroyTensor(tensors, devMem, tensorCount);
    DestroyStream(stream, 0);
    return 0;
}
